Implement BroadcastAxes::vmap #3319
zcbenz
left a comment
Can you explain how you are handling broadcasting with vmap? The code looks very confusing.
The
After that, I can also add a few comments in the implementation to make that flow easier to follow.
Well, that is just a rephrasing of what the code does. To be honest, I don't know if you actually understood the algorithm and explained it to me, or just pasted what AI said. I'm leaving this to @angeloskath to judge.
angeloskath
left a comment
Thanks for the PR but this is not quite correct.
- The tests are not touching the BroadcastAxes primitive at all, since it only appears when tracing, i.e. when using compile.
- In addition, the main point of this primitive is to make it easier to do shapeless compilation with partial broadcasting, so there should be no reshape. You can use the expand dims op, which uses a primitive for the same reason.
- Finally, and this can be personal preference, I think it is unnecessarily complex in some places. E.g.:
a. Why add axes[i] == 1 to ndim?
b. Why not use moveaxis instead of transpose?
c. Complicated axis math to re-express the ignore axes argument vs. keeping the negative and moving the batch axis first.
I will close this in favor of #3344.
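The last review point (move the batch axis first and keep the ignore axes negative) can be illustrated at the level of shapes alone. The sketch below is not MLX code and does not reproduce either PR: `move_batch_first` and `broadcast_ignoring` are hypothetical helpers written for this illustration, and it assumes that every input is batched and that the ignore axes are given as negative indices. Under those assumptions, prepending the batch axis leaves positions counted from the end unchanged, so the same negative ignore axes apply before and after batching.

```python
def move_batch_first(shape, batch_axis):
    """Shape-level analogue of moving `batch_axis` to position 0."""
    dims = list(shape)
    batch = dims.pop(batch_axis)  # list.pop accepts negative indices
    return (batch, *dims)


def broadcast_ignoring(shapes, ignore_axes):
    """Broadcast shapes against each other, leaving `ignore_axes` untouched.

    Hypothetical helper: `ignore_axes` must be negative (counted from the end).
    """
    if any(ax >= 0 for ax in ignore_axes):
        raise ValueError("this sketch expects negative ignore axes")
    ndim = max(len(s) for s in shapes)
    # Right-align all shapes by padding with leading ones.
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    ignored = {ax + ndim for ax in ignore_axes}

    target = []
    for i in range(ndim):
        if i in ignored:
            target.append(None)  # each input keeps its own size here
            continue
        sizes = {p[i] for p in padded if p[i] != 1}
        if len(sizes) > 1:
            raise ValueError(f"cannot broadcast axis {i}: {sorted(sizes)}")
        target.append(sizes.pop() if sizes else 1)

    return [
        tuple(p[i] if t is None else t for i, t in enumerate(target))
        for p in padded
    ]


# Per-example shapes, broadcasting everything except the last axis.
per_example = broadcast_ignoring([(1, 7), (5, 2)], ignore_axes=[-1])

# Same inputs with a batch of 8 at different positions (axis 1 and axis -1).
batched_inputs = [
    move_batch_first((1, 8, 7), 1),
    move_batch_first((5, 2, 8), -1),
]
# The ignore axes are reused unchanged because they count from the end.
batched = broadcast_ignoring(batched_inputs, ignore_axes=[-1])

print(per_example)  # [(5, 7), (5, 2)]
print(batched)      # [(8, 5, 7), (8, 5, 2)]
```

In this sketch the batched result is the per-example result with the batch size prepended, and the output batch axis is always 0. An unbatched input would need extra handling that is left out here.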
Summary
Implements BroadcastAxes::vmap in mlx/primitives.cpp, replacing the NYI stub and enabling vmapped behavior for BroadcastAxes-backed ops.

What changed
- BroadcastAxes::vmap in C++ core.
- ignore_axes_ to the aligned representation.

Tests
C++
- tests/vmap_tests.cpp: test vmap broadcast axes primitive
- ({0,1} and {-1,0}) and output-axis expectations.

Python
- python/tests/test_vmap.py: test_vmap_broadcast_to
- in_axes=0, in_axes=1, and in_axes=-1, out_axes=-1.
- take_along_axis and put_along_axis vmap cases.

Closes the NYI stub at mlx/primitives.cpp:907.